import dalex as dx
import pandas as pd
import pickle
import xgboost as xgb
from sklearn.model_selection import train_test_split
input_df = pd.read_csv('new_preprocessed_dataset.csv')
y = input_df.loc[:,'Attrition']
X = input_df.drop('Attrition', axis='columns')
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=14)
path = '../../../Projekt/Modele/BarteKasiAdam/new_xgb_model.p'
with open(path, "rb") as f:
    model = pickle.load(f)
explainer = dx.Explainer(model, X_train, y_train)
Preparation of a new explainer is initiated

  -> data              : 7595 rows 21 cols
  -> target variable   : Parameter 'y' was a pandas.Series. Converted to a numpy.ndarray.
  -> target variable   : 7595 values
  -> model_class       : xgboost.sklearn.XGBClassifier (default)
  -> label             : Not specified, model's class short name will be used. (default)
  -> predict function  : <function yhat_proba_default at 0x7fa971ad1820> will be used (default)
  -> predict function  : Accepts pandas.DataFrame and numpy.ndarray.
  -> predicted values  : min = 1.49e-06, mean = 0.159, max = 1.0
  -> model type        : classification will be used (default)
  -> residual function : difference between y and yhat (default)
  -> residuals         : min = -0.699, mean = 9.19e-05, max = 0.82
  -> model_info        : package xgboost

A new explainer has been created!
For observation 32, the model predicts with nearly 100% probability that the customer will not churn. In this case the prediction turned out to be correct.
index = 32
display(explainer.predict(X_test)[index])
display(y_test.iloc[index])
0.00017104982
0
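For context, `explainer.predict` returns the positive-class (churn) probability, so the value of roughly 0.00017 above corresponds to class 0 under the usual 0.5 decision rule. A minimal sketch with made-up probabilities (only the first value mirrors the one shown above):

```python
import numpy as np

# Hypothetical probabilities in the style of explainer.predict() output.
yhat_proba = np.array([0.00017104982, 0.73, 0.40])

# Default decision rule: predict churn (class 1) when p >= 0.5.
yhat_class = (yhat_proba >= 0.5).astype(int)
print(yhat_class.tolist())
```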
cp = explainer.predict_profile(X_test.iloc[index,:])
Calculating ceteris paribus: 100%|██████████| 21/21 [00:00<00:00, 103.05it/s]
cp
| | Gender | Card_Category | Missing_Income | Income_Category | Missing_Education | Education_Level | Marital_Divorced | Marital_Married | Marital_Single | Marital_Unknown | ... | Months_Inactive_12_mon | Contacts_Count_12_mon | Total_Revolving_Bal | Total_Amt_Chng_Q4_Q1 | Total_Ct_Chng_Q4_Q1 | _original_ | _yhat_ | _vname_ | _ids_ | _label_ |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 20 | 0.00 | 1.0 | 0.0 | 4.0 | 0.0 | 6.0 | 1.0 | 0.0 | 0.0 | 0.0 | ... | 2.0 | 0.0 | 1800.0 | 0.906 | 0.929 | 1.000 | 0.000179 | Gender | 20 | XGBClassifier |
| 20 | 0.01 | 1.0 | 0.0 | 4.0 | 0.0 | 6.0 | 1.0 | 0.0 | 0.0 | 0.0 | ... | 2.0 | 0.0 | 1800.0 | 0.906 | 0.929 | 1.000 | 0.000179 | Gender | 20 | XGBClassifier |
| 20 | 0.02 | 1.0 | 0.0 | 4.0 | 0.0 | 6.0 | 1.0 | 0.0 | 0.0 | 0.0 | ... | 2.0 | 0.0 | 1800.0 | 0.906 | 0.929 | 1.000 | 0.000179 | Gender | 20 | XGBClassifier |
| 20 | 0.03 | 1.0 | 0.0 | 4.0 | 0.0 | 6.0 | 1.0 | 0.0 | 0.0 | 0.0 | ... | 2.0 | 0.0 | 1800.0 | 0.906 | 0.929 | 1.000 | 0.000179 | Gender | 20 | XGBClassifier |
| 20 | 0.04 | 1.0 | 0.0 | 4.0 | 0.0 | 6.0 | 1.0 | 0.0 | 0.0 | 0.0 | ... | 2.0 | 0.0 | 1800.0 | 0.906 | 0.929 | 1.000 | 0.000179 | Gender | 20 | XGBClassifier |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 20 | 1.00 | 1.0 | 0.0 | 4.0 | 0.0 | 6.0 | 1.0 | 0.0 | 0.0 | 0.0 | ... | 2.0 | 0.0 | 1800.0 | 0.906 | 3.360 | 0.929 | 0.000108 | Total_Ct_Chng_Q4_Q1 | 20 | XGBClassifier |
| 20 | 1.00 | 1.0 | 0.0 | 4.0 | 0.0 | 6.0 | 1.0 | 0.0 | 0.0 | 0.0 | ... | 2.0 | 0.0 | 1800.0 | 0.906 | 3.395 | 0.929 | 0.000108 | Total_Ct_Chng_Q4_Q1 | 20 | XGBClassifier |
| 20 | 1.00 | 1.0 | 0.0 | 4.0 | 0.0 | 6.0 | 1.0 | 0.0 | 0.0 | 0.0 | ... | 2.0 | 0.0 | 1800.0 | 0.906 | 3.430 | 0.929 | 0.000108 | Total_Ct_Chng_Q4_Q1 | 20 | XGBClassifier |
| 20 | 1.00 | 1.0 | 0.0 | 4.0 | 0.0 | 6.0 | 1.0 | 0.0 | 0.0 | 0.0 | ... | 2.0 | 0.0 | 1800.0 | 0.906 | 3.465 | 0.929 | 0.000108 | Total_Ct_Chng_Q4_Q1 | 20 | XGBClassifier |
| 20 | 1.00 | 1.0 | 0.0 | 4.0 | 0.0 | 6.0 | 1.0 | 0.0 | 0.0 | 0.0 | ... | 2.0 | 0.0 | 1800.0 | 0.906 | 3.500 | 0.929 | 0.000108 | Total_Ct_Chng_Q4_Q1 | 20 | XGBClassifier |
2129 rows × 26 columns
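The idea behind the profile above can be reproduced by hand: repeat one observation across a grid of values for a single variable, hold every other feature fixed, and collect the predictions. A minimal sketch with a hypothetical two-feature model standing in for the XGBoost classifier (the coefficients and feature values are invented for illustration):

```python
import numpy as np
import pandas as pd

def toy_model(X: pd.DataFrame) -> np.ndarray:
    """Stand-in for model.predict_proba(...)[:, 1]: a logistic score
    over two hypothetical features."""
    z = 0.8 * X["Contacts_Count_12_mon"] - 0.001 * X["Total_Trans_Amt"]
    return 1.0 / (1.0 + np.exp(-z))

def ceteris_paribus(predict, obs: pd.Series, variable: str,
                    grid: np.ndarray) -> pd.DataFrame:
    """Repeat the observation once per grid point, overwrite the chosen
    variable with the grid value, and predict; all other features are
    held fixed ('ceteris paribus')."""
    profile = pd.DataFrame([obs] * len(grid)).reset_index(drop=True)
    profile[variable] = grid
    profile["_yhat_"] = predict(profile)
    profile["_vname_"] = variable
    return profile

obs = pd.Series({"Contacts_Count_12_mon": 2.0, "Total_Trans_Amt": 4000.0})
grid = np.arange(0.0, 7.0)  # 0..6 contacts in the last 12 months
cp_manual = ceteris_paribus(toy_model, obs, "Contacts_Count_12_mon", grid)
print(cp_manual[["Contacts_Count_12_mon", "_yhat_"]])
```

`dx.Explainer.predict_profile` does essentially this for every variable at once, which is why the resulting frame has one block of rows per feature.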
cp.plot(facet_ncol=5)
For this observation, changing the value of almost any variable leaves the model's prediction unchanged. Only changes in Total_Trans_Amt (total transaction amount) or Contacts_Count_12_mon (number of contacts with bank representatives in the last year) affect the output probability.
If Total_Trans_Amt were <= 860 or in the interval [6,000; 11,000], the model would predict the opposite outcome, namely that the customer churns.
If Contacts_Count_12_mon were greater than 5, the predicted probability of churn would increase by about 6 percentage points.
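Reading such class-flip regions off a profile amounts to filtering the grid points whose prediction crosses the 0.5 decision line. A sketch with invented numbers shaped like the behavior described above (a single-variable slice of a CP profile):

```python
import numpy as np
import pandas as pd

# Hypothetical one-variable profile: grid values and model outputs,
# chosen to mimic "churn predicted below 860 and between 6,000 and 11,000".
profile = pd.DataFrame({
    "Total_Trans_Amt": np.array(
        [500, 860, 2000, 4000, 6000, 8000, 11000, 14000], dtype=float),
    "_yhat_": np.array([0.90, 0.60, 0.10, 0.05, 0.70, 0.80, 0.55, 0.10]),
})

# Grid points where the churn probability crosses the 0.5 decision line.
flipped = profile.loc[profile["_yhat_"] > 0.5, "Total_Trans_Amt"]
print(flipped.tolist())
```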
explainer_b = dx.Explainer(model, X_train, y_train, label='index_b')
Preparation of a new explainer is initiated

  -> data              : 7595 rows 21 cols
  -> target variable   : Parameter 'y' was a pandas.Series. Converted to a numpy.ndarray.
  -> target variable   : 7595 values
  -> model_class       : xgboost.sklearn.XGBClassifier (default)
  -> label             : index_b
  -> predict function  : <function yhat_proba_default at 0x7fa971ad1820> will be used (default)
  -> predict function  : Accepts pandas.DataFrame and numpy.ndarray.
  -> predicted values  : min = 1.49e-06, mean = 0.159, max = 1.0
  -> model type        : classification will be used (default)
  -> residual function : difference between y and yhat (default)
  -> residuals         : min = -0.699, mean = 9.19e-05, max = 0.82
  -> model_info        : package xgboost

A new explainer has been created!
A second explainer object (with a different label) was created so that the two profiles are drawn in different colors on the shared plot.
index_a = 5274
index_b = 7997
#12,123; 57,123; 123,8425; 7409,7997; 5274,7997
cp_a = explainer.predict_profile(X_test.loc[index_a,:])
cp_b = explainer_b.predict_profile(X_test.loc[index_b,:])
cp_a.plot(cp_b, facet_ncol=6)
Calculating ceteris paribus: 100%|██████████| 21/21 [00:00<00:00, 52.13it/s]
Calculating ceteris paribus: 100%|██████████| 21/21 [00:00<00:00, 62.11it/s]
For one of the observations above, the churn probability rises as Customer_Age (customer age) increases, while for the other the trend is reversed.
The effect of Gender on the prediction also differs between these observations. For one it has no effect at all, while for the other it changes the output probability: the customer is in fact a woman, and if she were a man the predicted probability would change by about 0.2.
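For a binary variable such as Gender, the ceteris-paribus effect reduces to the difference between the two grid points of the profile. A sketch with invented `_yhat_` values mirroring the two observations described above (no effect for one, a 0.2 shift for the other):

```python
import pandas as pd

# Hypothetical CP rows for the binary variable Gender (0 = female,
# 1 = male) for two observations labeled "a" and "b".
cp_gender = pd.DataFrame({
    "_ids_": ["a", "a", "b", "b"],
    "Gender": [0.0, 1.0, 0.0, 1.0],
    "_yhat_": [0.30, 0.30, 0.25, 0.45],
})

# Effect of flipping Gender = yhat(Gender=1) - yhat(Gender=0), per observation.
effect = (cp_gender.pivot(index="_ids_", columns="Gender", values="_yhat_")
          .pipe(lambda t: t[1.0] - t[0.0]))
print(effect.to_dict())
```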